package v6

import (
	"errors"
	"fmt"
	"strings"
	"time"

	"github.com/scylladb/go-set/strset"
	"gorm.io/gorm"

	"github.com/anchore/go-logger"
	"github.com/anchore/grype/internal/log"
)

// anyVulnerability is the sentinel rendering used by String() when a specifier
// (or set of specifiers) places no constraints on which vulnerabilities match.
const anyVulnerability = "any"

// VulnerabilityStoreWriter is the interface for adding vulnerability records to the store.
type VulnerabilityStoreWriter interface {
	AddVulnerabilities(vulns ...*VulnerabilityHandle) error
}

// VulnerabilityStoreReader is the interface for fetching vulnerability records from the store.
type VulnerabilityStoreReader interface {
	GetVulnerabilities(vuln *VulnerabilitySpecifier, config *GetVulnerabilityOptions) ([]VulnerabilityHandle, error)
}

// GetVulnerabilityOptions configures how vulnerability records are fetched.
type GetVulnerabilityOptions struct {
	// Preload indicates whether the vulnerability blob values should be fetched and attached to the results.
	Preload bool

	// Limit is the maximum number of records to return (0 means no limit).
	Limit int
}

// VulnerabilitySpecifiers is a collection of VulnerabilitySpecifier values.
type VulnerabilitySpecifiers []VulnerabilitySpecifier

// VulnerabilitySpecifier describes criteria for selecting vulnerability records.
// A zero-value specifier matches any vulnerability.
type VulnerabilitySpecifier struct {
	// Name of the vulnerability (e.g. CVE-2020-1234)
	Name string

	// ID is the DB ID of the vulnerability
	ID ID

	// Status is the status of the vulnerability (e.g. "active", "rejected", etc.)
	Status VulnerabilityStatus

	// PublishedAfter is a filter to only return vulnerabilities published after the given time
	PublishedAfter *time.Time

	// ModifiedAfter is a filter to only return vulnerabilities modified after the given time
	ModifiedAfter *time.Time

	// IncludeAliases indicates that records whose aliases match the given name should also be returned
	IncludeAliases bool

	// Providers restricts results to vulnerabilities from the given provider IDs (empty means any provider)
	Providers []string
}

// String renders the specifier as a human-readable summary of its populated
// criteria, or "any" when no criteria are set.
func (v *VulnerabilitySpecifier) String() string {
	var criteria []string
	add := func(c string) { criteria = append(criteria, c) }

	if v.Name != "" {
		add(fmt.Sprintf("name=%s", v.Name))
	}
	if v.ID != 0 {
		add(fmt.Sprintf("id=%d", v.ID))
	}
	if v.Status != "" {
		add(fmt.Sprintf("status=%s", v.Status))
	}
	if v.PublishedAfter != nil {
		add(fmt.Sprintf("publishedAfter=%s", v.PublishedAfter.String()))
	}
	if v.ModifiedAfter != nil {
		add(fmt.Sprintf("modifiedAfter=%s", v.ModifiedAfter.String()))
	}
	if v.IncludeAliases {
		add("includeAliases=true")
	}
	if len(v.Providers) > 0 {
		add(fmt.Sprintf("providers=%s", strings.Join(v.Providers, ",")))
	}

	if len(criteria) == 0 {
		return anyVulnerability
	}
	return fmt.Sprintf("vulnerability(%s)", strings.Join(criteria, ", "))
}

// String renders every specifier in the collection, comma-separated, or "any"
// when the collection is empty.
func (s VulnerabilitySpecifiers) String() string {
	if len(s) == 0 {
		return anyVulnerability
	}
	descriptions := make([]string, 0, len(s))
	for i := range s {
		descriptions = append(descriptions, s[i].String())
	}
	return strings.Join(descriptions, ", ")
}

// DefaultGetVulnerabilityOptions returns the options used when the caller
// provides none: no blob preloading and no result limit.
func DefaultGetVulnerabilityOptions() *GetVulnerabilityOptions {
	// the zero value (Preload=false, Limit=0) is the default behavior
	return &GetVulnerabilityOptions{}
}

// vulnerabilityStore reads and writes VulnerabilityHandle records, delegating
// the associated blob payloads to the blob store.
type vulnerabilityStore struct {
	db        *gorm.DB
	blobStore *blobStore
}

// newVulnerabilityStore creates a vulnerability store backed by the given DB
// handle and blob store.
func newVulnerabilityStore(db *gorm.DB, bs *blobStore) *vulnerabilityStore {
	store := vulnerabilityStore{
		db:        db,
		blobStore: bs,
	}
	return &store
}

// AddVulnerabilities writes the given vulnerability handles to the store,
// including their provider records, blob payloads, and alias rows.
func (s *vulnerabilityStore) AddVulnerabilities(vulnerabilities ...*VulnerabilityHandle) error {
	if err := s.addProviders(s.db, vulnerabilities...); err != nil {
		return fmt.Errorf("unable to add providers: %w", err)
	}
	for i := range vulnerabilities {
		v := vulnerabilities[i]
		// this adds the blob value to the DB and sets the ID on the vulnerability handle
		if err := s.blobStore.addBlobable(v); err != nil {
			// fix: this is the vulnerability blob, not an "affected" blob (message was copy-pasted)
			return fmt.Errorf("unable to add vulnerability blob: %w", err)
		}

		if v.PublishedDate != nil && v.ModifiedDate == nil {
			// the data here should be consistent, and we are norming around initial publication counts as a modification date.
			// this allows for easily refining queries based on both publication date and modification date without needing
			// to worry about this edge case.
			v.ModifiedDate = v.PublishedDate
		}

		if v.BlobValue != nil {
			// persist each distinct alias, excluding the vulnerability's own name
			aliases := strset.New(v.BlobValue.Aliases...)
			aliases.Remove(v.Name)
			for _, alias := range aliases.List() {
				aliasModel := VulnerabilityAlias{
					Name:  v.Name,
					Alias: alias,
				}
				if err := s.db.FirstOrCreate(&aliasModel).Error; err != nil {
					return fmt.Errorf("unable to create vulnerability alias record: %w", err)
				}
			}
		}
		if err := createRecordsWithCache(s.db, v); err != nil {
			return err
		}
	}

	return nil
}

// addProviders persists the (deduplicated) provider records referenced by the
// given vulnerabilities and back-fills the resulting provider IDs onto each
// handle. Providers already written in a previous transaction are resolved
// from the context cache instead of being re-created.
func (s *vulnerabilityStore) addProviders(tx *gorm.DB, vulnerabilities ...*VulnerabilityHandle) error { // nolint:dupl
	cacheInst, ok := cacheFromContext(tx.Statement.Context)
	if !ok {
		return fmt.Errorf("unable to fetch provider cache from context")
	}

	// final: providers that actually need an INSERT in this transaction.
	// byCacheKey: every provider reference grouped by its cache key, so each
	// reference can be updated with the canonical row ID afterwards.
	var final []*Provider
	byCacheKey := make(map[string][]*Provider)
	for _, v := range vulnerabilities {
		if v.Provider != nil {
			key := v.Provider.cacheKey()
			if existingID, ok := cacheInst.getString(v.Provider); ok {
				// seen in a previous transaction...
				v.ProviderID = existingID
			} else if _, ok := byCacheKey[key]; !ok {
				// not seen within this transaction
				final = append(final, v.Provider)
			}
			byCacheKey[key] = append(byCacheKey[key], v.Provider)
		}
	}

	if len(final) == 0 {
		return nil
	}

	if err := tx.Create(final).Error; err != nil {
		return fmt.Errorf("unable to create provider records: %w", err)
	}

	// update the cache with the new records
	for _, ref := range final {
		cacheInst.set(ref)
	}

	// update all references with the IDs from the cache
	for _, refs := range byCacheKey {
		for _, ref := range refs {
			id, ok := cacheInst.getString(ref)
			if ok {
				ref.setRowID(id)
			}
		}
	}

	// update the parent objects with the FK ID
	for _, p := range vulnerabilities {
		if p.Provider != nil {
			p.ProviderID = p.Provider.ID
		}
	}
	return nil
}

// createRecordsWithCache inserts the given vulnerability handles, skipping any
// record already present in the context cache or duplicated within this call.
// Every record (written or skipped) ends up with its row ID populated.
func createRecordsWithCache(tx *gorm.DB, items ...*VulnerabilityHandle) error {
	// look for existing records from the cache, and only create new records
	cacheInst, ok := cacheFromContext(tx.Statement.Context)
	if !ok {
		return fmt.Errorf("cache not found in context")
	}

	// store all entries by their cache key (throw away duplicates)
	skippedRecordsByCacheKey := map[string][]*VulnerabilityHandle{}
	usedKeys := strset.New()
	var finalWrites []*VulnerabilityHandle
	for i := range items {
		p := items[i]
		key := p.cacheKey()

		// duplicate of a record already queued for writing in this call
		if usedKeys.Has(key) {
			skippedRecordsByCacheKey[key] = append(skippedRecordsByCacheKey[key], p)
			continue
		}

		// duplicate of a record already skipped in this call (via the cache hit below)
		if _, ok := skippedRecordsByCacheKey[key]; ok {
			skippedRecordsByCacheKey[key] = append(skippedRecordsByCacheKey[key], p)
			continue
		}
		// already written in a previous transaction
		if _, ok := cacheInst.getID(p); ok {
			skippedRecordsByCacheKey[key] = append(skippedRecordsByCacheKey[key], p)
			continue
		}

		finalWrites = append(finalWrites, p)
		usedKeys.Add(key)
	}

	// Provider rows are written separately (see addProviders); only the FK is persisted here
	for i := range finalWrites {
		if err := tx.Omit("Provider").Create(finalWrites[i]).Error; err != nil {
			return fmt.Errorf("unable to create record %#v: %w", finalWrites[i], err)
		}
	}

	// ensure we're always updating the cache with the latest data + the records with any new IDs
	for i := range finalWrites {
		cacheInst.set(finalWrites[i])
	}

	// propagate the (now known) row IDs onto every skipped duplicate record
	for _, batch := range skippedRecordsByCacheKey {
		for i := range batch {
			id, ok := cacheInst.getID(batch[i])
			if !ok {
				return fmt.Errorf("unable to find ID: %#v", batch[i])
			}
			batch[i].setRowID(id)
		}
	}

	return nil
}

// GetVulnerabilities fetches vulnerability records matching the given
// specifier (nil matches all records), honoring the preload and limit
// behavior in config (nil uses DefaultGetVulnerabilityOptions).
func (s *vulnerabilityStore) GetVulnerabilities(vuln *VulnerabilitySpecifier, config *GetVulnerabilityOptions) ([]VulnerabilityHandle, error) {
	if config == nil {
		config = DefaultGetVulnerabilityOptions()
	}
	fields := logger.Fields{
		"vuln":    vuln,
		"preload": config.Preload,
	}
	start := time.Now()
	var count int
	defer func() {
		fields["duration"] = time.Since(start)
		fields["records"] = count
		log.WithFields(fields).Trace("fetched vulnerability records")
	}()

	var err error
	query := s.db
	if vuln != nil {
		query, err = handleVulnerabilityOptions(s.db, query, *vuln)
		if err != nil {
			return nil, err
		}
	}

	query = s.handlePreload(query, *config)

	var models []VulnerabilityHandle

	var results []*VulnerabilityHandle
	if err := query.FindInBatches(&results, batchSize, func(_ *gorm.DB, _ int) error {
		if config.Preload {
			var blobs []blobable
			for _, r := range results {
				blobs = append(blobs, r)
			}
			if err := s.blobStore.attachBlobValue(blobs...); err != nil {
				return fmt.Errorf("unable to attach vulnerability blobs: %w", err)
			}
		}

		for _, r := range results {
			models = append(models, *r)
		}

		count += len(results)

		if config.Limit > 0 && len(models) >= config.Limit {
			// stop fetching further batches; this is expected control flow, not a failure
			return ErrLimitReached
		}

		return nil
	}).Error; err != nil && !errors.Is(err, ErrLimitReached) {
		// fix: ErrLimitReached previously surfaced to callers as a wrapped failure,
		// making any limited query error out instead of returning its results
		return models, fmt.Errorf("unable to fetch vulnerability records: %w", err)
	}

	return models, nil
}

// handlePreload applies the configured row limit to the query and, when
// requested, eager-loads the Provider association (bounded by the same limit).
func (s *vulnerabilityStore) handlePreload(query *gorm.DB, config GetVulnerabilityOptions) *gorm.DB {
	var preloadArgs []interface{}
	if config.Limit > 0 {
		query = query.Limit(config.Limit)
		preloadArgs = append(preloadArgs, func(db *gorm.DB) *gorm.DB {
			return db.Limit(config.Limit)
		})
	}
	if !config.Preload {
		return query
	}
	return query.Preload("Provider", preloadArgs...)
}

// handleVulnerabilityOptions translates the given specifiers into a WHERE
// clause on parentQuery, OR-ing the per-specifier condition groups together.
// The alias table is joined only when at least one specifier requests alias
// matching. base is used to build fresh sub-query conditions; the error return
// is always nil today but kept for interface symmetry with sibling stores.
func handleVulnerabilityOptions(base, parentQuery *gorm.DB, configs ...VulnerabilitySpecifier) (*gorm.DB, error) {
	if len(configs) == 0 {
		return parentQuery, nil
	}

	orConditions := base.Model(&VulnerabilityHandle{})
	var includeAliasJoin bool
	for _, config := range configs {
		query := base.Model(&VulnerabilityHandle{})
		if config.Name != "" {
			if config.IncludeAliases {
				includeAliasJoin = true
				// match either the record's own name or any of its aliases, case-insensitively
				query = query.Where("vulnerability_handles.name = ? collate nocase OR vulnerability_aliases.alias = ?  collate nocase", config.Name, config.Name)
			} else {
				query = query.Where("vulnerability_handles.name = ?  collate nocase", config.Name)
			}
		}

		if config.ID != 0 {
			query = query.Where("vulnerability_handles.id = ?", config.ID)
		}

		if config.PublishedAfter != nil {
			query = query.Where("vulnerability_handles.published_date > ?", *config.PublishedAfter)
		}

		if config.ModifiedAfter != nil {
			query = query.Where("vulnerability_handles.modified_date > ?", *config.ModifiedAfter)
		}

		if config.Status != "" {
			query = query.Where("vulnerability_handles.status = ?", config.Status)
		}

		if len(config.Providers) > 0 {
			query = query.Where("vulnerability_handles.provider_id IN ?", config.Providers)
		}

		// conditions within a specifier are ANDed; specifiers are ORed together
		orConditions = orConditions.Or(query)
	}

	if includeAliasJoin {
		parentQuery = parentQuery.Joins("LEFT JOIN vulnerability_aliases ON vulnerability_aliases.name = vulnerability_handles.name collate nocase")
	}

	return parentQuery.Where(orConditions), nil
}
